---
title: "Debug R Effective"
output: html_notebook
author: Eric Marty
---
  
```{r setup, include=FALSE, echo=FALSE, warning=FALSE, message=FALSE}
# Global chunk defaults: show chunk output only (code, warnings and messages
# hidden) and let figures fill the page width.
knitr::opts_chunk$set(
  include = TRUE,
  echo = FALSE,
  warning = FALSE,
  message = FALSE,
  out.width = "100%"
)

# Project-relative paths, plus dplyr/tidyr/ggplot2/purrr/tibble.
library(here)
library(tidyverse)
```

```{r functions}
# Detection probability q(t).
#
# `t` is a numeric time vector (the data chunk passes days since the earliest
# date in a state's output, i.e. as.numeric(date - firstcasedate)); `params`
# is a one-row data frame / list of parameters in natural units. Returns a
# vector the same length as `t`.
#
# A Hill curve t^n / (t_half^n + t^n), with n = detect_rampup and
# t_half = t_half_detect, is linearly mapped from [0, 1] onto
# [base_detect_frac, max_detect_frac]. For t >= 0 and n > 0 the result
# therefore rises from base_detect_frac towards max_detect_frac.
#
# NOTE(review): this is the form defined in states-model.rmd. The expression
# in makepompmodel uses exp(log_detect_inc_rate)^t (exponential in t rather
# than a Hill curve in t) and does not obviously agree. Confirm which one is
# authoritative before trusting q. Reference expressions (not evaluated):
#   makepompmodel:
#     1/(1+exp(max_detect_par)) * exp(log_detect_inc_rate)^t / (exp(log_detect_inc_rate)^exp(log_half_detect) + exp(log_detect_inc_rate)^t) + base_detect_frac
#   earlier attempted translation to natural units:
#     q <- 1/(1 + q_max) * q_n ^ t / (q_n ^ q_th + q_n ^ t) + q_min
#   Eamon:
#     1/(1+exp(max_detect_par)) * (t ^ exp(log_detect_inc_rate))  / ( (exp(log_half_detect) ^ exp(log_detect_inc_rate)) + (t ^ exp(log_detect_inc_rate))) + exp(base_detect_frac)
# How t could be built from a date vector:
#   dates <- seq(min(dates),max(dates),1); t <- seq_along(dates)
get_q <- function(t, params) {
  hill_coef <- params$detect_rampup
  half_time <- params$t_half_detect

  # t^n is needed in both numerator and denominator; compute it once.
  t_pow <- t^hill_coef
  ramp <- t_pow / (half_time^hill_coef + t_pow)

  scales::rescale(
    ramp,
    to = c(params$base_detect_frac, params$max_detect_frac),
    from = c(0, 1)
  )
}

# Diagnosis-rate multiplier s(t).
#
# `t` is a numeric time vector (the data chunk passes days since the earliest
# date in a state's output); `params` is a one-row data frame / list of
# parameters in natural units. Returns a vector the same length as `t`:
#
#   s(t) = 1 + s_max * s_n^t / (s_n^s_th + s_n^t)
#        = 1 + s_max / (1 + s_n^(s_th - t))
#
# with s_max = max_diag_factor, s_n = diag_rampup, s_th = t_half_diag.
# For s_n > 1 this is a logistic curve in t, rising from ~1 towards
# 1 + s_max and equal to 1 + s_max / 2 at t = s_th.
#
# This mirrors the makepompmodel expression (assuming the natural-unit params
# are exp() of the log-scale ones):
#   1 + exp(log_max_diag) * exp(log_diag_inc_rate)^t / ( exp(log_diag_inc_rate)^exp(log_half_diag) +   exp(log_diag_inc_rate)^t)
# Note that diag_rampup is the *base* of the exponent (a growth rate), not a
# Hill coefficient as detect_rampup is in get_q().
get_s <- function(t, params) {
  s_max <- params$max_diag_factor
  s_n <- params$diag_rampup
  s_th <- params$t_half_diag

  # Evaluate the logistic as 1 / (1 + s_n^(s_th - t)). The direct form
  # s_n^t / (s_n^s_th + s_n^t) overflows to Inf / Inf = NaN once s_n^t
  # exceeds the double range (e.g. s_n = 10, t > ~308). This form tends
  # cleanly to 1 (large t) or 0 (small t) instead.
  s <- 1 + s_max / (1 + s_n^(s_th - t))

  s
}
  
# Effective reproduction number R_e.
#
# Arguments (vector arguments are combined element-wise):
#   S      susceptibles. The data chunk passes an absolute count,
#          state_pop - cumulative_all_infections, not a fraction.
#   omega  transmission rate (the data chunk passes combined_trend * beta_s).
#   q      detection probability, e.g. from get_q().
#   s      diagnosis-rate multiplier, e.g. from get_s().
#   params one-row data frame / list of parameters in natural units.
# Transition times (time_*) are expected per class, not per sub-compartment.
# Returns a vector of R_e values.
#
# R_e = S * omega * [ latent
#                     + (1 - a) * (q * detected + (1 - q) * undetected)
#                     + a * asymptomatic ]
# Each pathway term is relative transmissibility times mean time in class.
# Symptomatic transmissibility (I_su, I_sd) is the reference value 1.
#
# Earlier signature: getReff(S, N = 1, omega, q, s, params), with S <- S/N.
getReff <- function(S, omega, q, s, params) {
  frac_a <- params$frac_asym
  frac_h <- params$frac_hosp

  # Per-class exit rates.
  rate_latent <- 1 / params$time_e
  rate_asym <- 1 / params$time_a
  rate_undet <- 1 / params$time_su
  rate_det <- 1 / params$time_sd
  rate_case <- 1 / params$time_c
  rate_hosp <- 1 / params$time_h

  # Relative transmissibility of symptomatic infecteds (reference level).
  rel_sym <- 1

  latent_term <- params$frac_trans_e / rate_latent
  # Detected path: faster diagnosis (s > 1) shortens time in I_sd; the C term
  # scales with s as in the model; a fraction h is hospitalised.
  detected_term <- rel_sym / (s * rate_det) +
    params$frac_trans_c * s / rate_case +
    frac_h * params$frac_trans_h / rate_hosp
  symptomatic_term <- q * detected_term + (1 - q) * rel_sym / rate_undet
  asym_term <- frac_a * params$frac_trans_a / rate_asym

  S * omega * (latent_term + (1 - frac_a) * symptomatic_term + asym_term)
}
```

```{r data}
# Per-state model outputs in output/current/:
#   <state>.csv               simulation summaries, one file per state
#   *params-natural.rds       MLE parameters in natural units, one per state
#   statedf.rds               state metadata (state_full, total_pop, ...)
# list.files() returns names sorted alphabetically. The i-th .csv is paired
# with the i-th params file by that position.
all_files <- list.files(path = here::here("output/current/"), pattern = "\\.csv$")
param_files <- list.files(
  path = here::here("output/current/"),
  pattern = "params-natural\\.rds$"
)
if (length(all_files) == 0) {
  stop("No state .csv files found in output/current/.", call. = FALSE)
}
if (length(all_files) != length(param_files)) {
  stop(
    "Found ", length(all_files), " state .csv files but ",
    length(param_files), " params-natural.rds files in output/current/; ",
    "they are paired by position, so the counts must match.",
    call. = FALSE
  )
}

# Placeholders for the disabled parameter / log-likelihood collection below;
# not populated in this notebook.
state_parameters <- tibble()
state_logliks <- tibble()
# Disabled per-state collection (kept for reference; uses the raw params file):
#   tmpparams <- readRDS(tmpparamfile)
#   tmp_loglik <- data.frame(location = unique(tmp$location),
#                         log_lik = tmpparams["LogLik", 2])
#   rnms <- row.names(tmpparams)
#   tmpparams <- tmpparams %>%
#     mutate(parameter = rnms) %>%
#     filter(is_fitted == "yes") %>%
#     dplyr::select(-is_fitted) %>%
#     gather("key", "value", -parameter) %>%
#     filter(key == "X1") %>%
#     dplyr::select(-key) %>%
#     mutate(location = unique(tmp$location)) %>%
#     dplyr::select(location, value, parameter)
#   state_parameters <- bind_rows(state_parameters, tmpparams)
#   state_logliks <- bind_rows(state_logliks, tmp_loglik)

statedf <- readRDS(here::here("output/current", "statedf.rds"))
statevec <- sub("\\.csv$", "", all_files)
allstates_pop <- statedf %>%
  filter(state_full %in% statevec) %>%
  pull(total_pop) %>%
  sum()

# Read one *params-natural.rds and return a single-row tibble with one column
# per parameter, fixed and fitted combined. The file holds a data frame with
# parameter names as row names and columns is_fitted ("yes"/"no") and X1.
read_mle_params <- function(path) {
  statepars <- readRDS(path) %>% rownames_to_column(var = "param")
  widen <- function(fitted_flag) {
    statepars %>%
      dplyr::filter(is_fitted == fitted_flag) %>%
      select(param, X1) %>%
      pivot_wider(values_from = X1, names_from = param)
  }
  bind_cols(
    widen("no") %>% select(-c(MIF_ID, LogLik, LogLik_SE)),
    widen("yes")
  )
}

# Load one state's simulation summaries (status-quo scenario or no scenario)
# and add derived columns: prevalence, omega, S, susceptible_fraction, q, s,
# and R_e. Returns long format: location, sim_type, period, date, variable,
# mean_value.
summarise_state <- function(do_file, param_file) {
  state_name <- sub("\\.csv$", "", do_file)
  # .env$ so a same-named column in statedf cannot shadow the local variable.
  state_pop <- statedf %>%
    filter(state_full == .env$state_name) %>%
    pull(total_pop)
  if (length(state_pop) != 1) {
    stop(
      "Expected exactly one statedf row for '", state_name, "', found ",
      length(state_pop), ".",
      call. = FALSE
    )
  }
  statepars_allmle <- read_mle_params(here::here("output/current", param_file))

  tmp <- read.csv(here::here("output/current", do_file)) %>%
    mutate(date = as.Date(date))
  # Time origin for q(t) and s(t): the earliest date in the file.
  firstcasedate <- min(tmp$date)

  tmp %>%
    filter(sim_type == "status_quo" | is.na(sim_type),
           variable %in% c("daily_cases", "daily_deaths", "daily_all_infections",
                           "actual_daily_cases", "actual_daily_deaths",
                           "mobility_trend", "latent_trend", "combined_trend",
                           "cumulative_all_infections", "cumulative_deaths")) %>%
    dplyr::select(location, sim_type, period, date, variable, mean_value) %>%
    pivot_wider(names_from = variable, values_from = mean_value) %>%
    mutate(prevalence = daily_all_infections / (state_pop - cumulative_deaths)) %>%
    mutate(omega = combined_trend * statepars_allmle$beta_s) %>%
    # S is an absolute count; getReff() is called with this, not the fraction.
    mutate(S = state_pop - cumulative_all_infections) %>%
    mutate(susceptible_fraction = S / (state_pop - cumulative_deaths)) %>%
    mutate(q = get_q(t = as.numeric(date - firstcasedate), params = statepars_allmle)) %>%
    mutate(s = get_s(t = as.numeric(date - firstcasedate), params = statepars_allmle)) %>%
    mutate(R_e = getReff(S = S,
                         omega = omega,
                         q = q,
                         s = s,
                         params = statepars_allmle)) %>%
    pivot_longer(cols = !c(location, sim_type, period, date),
                 names_to = "variable",
                 values_to = "mean_value",
                 values_drop_na = TRUE)
}

# Build every state's table, then bind once (no quadratic growth in a loop).
state_summaries <- map2(all_files, param_files, summarise_state) %>% bind_rows()

# Key Dates
future <- state_summaries %>% filter(period == "Future") %>% pull(date) %>% range()
burnin <- c(state_summaries %>% pull(date) %>% min(), as.Date("2020-03-01"))
sample_period <- c(as.Date("2020-03-01"), as.Date("2020-12-31"))

```



```{r R_effective, fig.height=3, fig.width=7, eval = TRUE}
# Per-state R_e up to today, joined to state metadata for population weights.
all_re <- state_summaries %>%
  filter(variable == "R_e") %>%
  dplyr::select(location, date, mean_value) %>%
  rename(R_e = mean_value) %>%
  filter(date <= Sys.Date()) %>%
  left_join(statedf, by = c("location" = "state_full")) %>%
  mutate(relative_pop = total_pop / allstates_pop)

# Unweighted mean across states (kept for comparison; not plotted).
mean_re <- all_re %>% group_by(date) %>% summarise(R_e = mean(R_e))

# Population-weighted mean across states. weighted.mean() normalises by the
# total population of the states present on each date. The previous
# sum(R_e * relative_pop) was biased towards 0 on any date where some states
# had no R_e row.
weighted_mean_re <- all_re %>%
  group_by(date) %>%
  summarise(R_e = weighted.mean(R_e, w = total_pop))

g_re <- ggplot(all_re, aes(x = date, y = R_e, color = location)) +
  geom_line(size = 0.5) +
  scale_y_continuous(limits = c(0, 5)) +
  scale_color_viridis_d(option = "D", direction = -1, alpha = .6, end = .75) +
  # theme_dark(base_line_size = 0.5) +
  geom_line(data = weighted_mean_re, alpha = .6, size = 2, color = "red") +
  # geom_rect(aes(xmin=future[1], xmax=future[2], ymin=0, ymax=1)) +
  annotate("rect", xmin = future[1], xmax = future[2], ymin = 0, ymax = 5, alpha = .2) +
  geom_hline(yintercept = 1) +
  theme_minimal() +
  guides(color = "none")

p_re <- g_re %>% plotly::ggplotly() %>%
  plotly::layout(showlegend = FALSE,
                 # margin = list(l = 150),
                 yaxis = list(title = "R_e")
                 )

# Relabel the hover text of the weighted-mean trace. Find the trace by its red
# line colour instead of a hard-coded index (the old [[51]] assumed exactly 50
# state traces). This breaks or hits the wrong trace when the state count
# changes.
# NOTE(review): assumes ggplotly encodes colour "red" via plotly::toRGB() as
# "rgba(255,0,0,<alpha>)"; verify if the plotly version changes.
is_mean_trace <- vapply(
  p_re$x$data,
  function(tr) any(grepl("rgba(255,0,0,", tr$line$color, fixed = TRUE)),
  logical(1)
)
if (!any(is_mean_trace)) {
  warning("Could not locate the weighted-mean trace; hover text not relabelled.",
          call. = FALSE)
}
for (k in which(is_mean_trace)) {
  p_re$x$data[[k]]$text <- gsub("red", "Mean of all states", p_re$x$data[[k]]$text)
}

p_re
```
